import os
os.chdir("../")
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import matplotlib.pyplot as plt

# Plot palette as hex RGB strings. The order of ``color_list`` matters:
# curves pick their colour by (negative) index into it.
CB91_Blue, CB91_Green, CB91_Pink = '#2CBDFE', '#47DBCD', '#F3A0F2'
CB91_Purple, CB91_Violet, CB91_Amber = '#9D2EC5', '#661D98', '#F5B14C'
seventh_color = "#4b97ec"
color_list = [
    CB91_Blue,
    CB91_Pink,
    CB91_Green,
    CB91_Amber,
    CB91_Purple,
    CB91_Violet,
    seventh_color,
]

# Location of the saved result arrays (relative to the working directory
# set by the chdir at the top of the file) and experiment dimensions.
RESULT_PATH = "results/thesis_submission_results/"
NUM_RUNS = 4
NUM_EPISODES = 12000
# x-axis values: every 20th episode index in [0, NUM_EPISODES).
episodes_indices = list(range(0, NUM_EPISODES, 20))

def _smooth_runs(runs, initial=20.0, alpha=0.01):
    """Return an exponentially smoothed float64 copy of ``runs``.

    ``runs`` is expected to be shaped ``(num_runs, num_points, ...)``. Every
    run is smoothed independently along axis 1 with the recurrence
    ``s_j = (1 - alpha) * s_{j-1} + alpha * x_j`` seeded with
    ``s_{-1} = initial``; any trailing axes are smoothed elementwise.

    Working on a float64 copy leaves the loaded data untouched and avoids
    truncating the running average if a result file was saved with an
    integer dtype (assigning a float into an int array drops the fraction).

    Raises:
        ValueError: if ``runs`` has fewer than two dimensions or no runs.
    """
    smoothed = np.array(runs, dtype=np.float64)
    if smoothed.ndim < 2 or smoothed.shape[0] == 0:
        raise ValueError(
            f"expected a non-empty (runs, points, ...) array, got shape {smoothed.shape}"
        )
    decay = 1.0 - alpha
    # One running value per run (and per trailing element), advanced in lockstep.
    running = np.full(smoothed.shape[:1] + smoothed.shape[2:], float(initial))
    for j in range(smoothed.shape[1]):
        running = decay * running + alpha * smoothed[:, j]
        smoothed[:, j] = running
    return smoothed


def _load_smoothed(filename):
    """Load ``RESULT_PATH + filename``, smooth it and summarise across runs.

    Returns:
        (smoothed, mean, sem): the smoothed per-run curves, their mean over
        runs (axis 0), and the standard error of that mean. The SEM divides
        by the number of runs actually present in the file rather than the
        hard-coded ``NUM_RUNS``, so it stays correct if a file holds a
        different number of seeds. ``np.std`` keeps its default ``ddof=0``,
        matching the previous computation.
    """
    smoothed = _smooth_runs(np.load(RESULT_PATH + filename))
    mean = np.mean(smoothed, axis=0)
    sem = np.std(smoothed, axis=0) / np.sqrt(smoothed.shape[0])
    return smoothed, mean, sem


# Baselines.
sender_ttpb_iql, sender_ttpb_iql_mean, sender_ttpb_iql_std = _load_smoothed(
    "iql_sender_time_to_pb_argmax.npy")
sender_ttpb_iql_ir, sender_ttpb_iql_ir_mean, sender_ttpb_iql_ir_std = _load_smoothed(
    "iql_sender_time_to_pb_ir_argmax.npy")
sender_ttpb_obl, sender_ttpb_obl_mean, sender_ttpb_obl_std = _load_smoothed(
    "obl_sender_time_to_pb_argmax.npy")
sender_ttpb_obl_ir, sender_ttpb_obl_ir_mean, sender_ttpb_obl_ir_std = _load_smoothed(
    "obl_sender_time_to_pb_ir_argmax.npy")

# Mutual-information variants of OBL.
sender_ttpb_obl_mi, sender_ttpb_obl_mi_mean, sender_ttpb_obl_mi_std = _load_smoothed(
    "obl_sender_time_to_pb_mi_log2_argmax.npy")
sender_ttpb_obl_mi_loss, sender_ttpb_obl_mi_loss_mean, sender_ttpb_obl_mi_loss_std = _load_smoothed(
    "obl_sender_time_to_pb_mi_loss_argmax.npy")
sender_ttpb_obl_mi_mi_loss, sender_ttpb_obl_mi_mi_loss_mean, sender_ttpb_obl_mi_mi_loss_std = _load_smoothed(
    "obl_sender_time_to_pb_mi_log2_mi_loss_argmax.npy")
#
# sender_ttpb_obl_mi_mi_loss_iql_util = np.load(RESULT_PATH + "obl_sender_time_to_pb_mi_log2_mi_loss_util_iql_argmax.npy")
# sender_ttpb_obl_mi_mi_loss_iql_util_mean = np.mean(sender_ttpb_obl_mi_mi_loss_iql_util, axis = 0)
#
# sender_ttpb_obl_dial_mi_mi_loss = np.load(RESULT_PATH + "obl_dial_sender_time_to_pb_mi_log2_mi_loss_argmax.npy")
# sender_ttpb_obl_dial_mi_mi_loss_mean = np.mean(sender_ttpb_obl_dial_mi_mi_loss, axis = 0)

# Plot the mean smoothed time-to-phone-booth curve of each baseline.
# Each entry is (mean curve, SEM, legend label, colour); colours are taken
# from the end of ``color_list``. The shaded +/- SEM band is disabled.
baseline_series = [
    (sender_ttpb_iql_mean, sender_ttpb_iql_std, "IQL", color_list[-1]),
    (sender_ttpb_iql_ir_mean, sender_ttpb_iql_ir_std, "IQL + IR", color_list[-2]),
    (sender_ttpb_obl_mean, sender_ttpb_obl_std, "OBL", color_list[-3]),
    (sender_ttpb_obl_ir_mean, sender_ttpb_obl_ir_std, "OBL + IR", color_list[-4]),
    # Disabled MI variants:
    # (sender_ttpb_obl_mi_mean, sender_ttpb_obl_mi_std, "OBL + MI Reward", color_list[0]),
    # (sender_ttpb_obl_mi_loss_mean, sender_ttpb_obl_mi_loss_std, "OBL + MI Loss", color_list[1]),
    # (sender_ttpb_obl_mi_mi_loss_mean, sender_ttpb_obl_mi_mi_loss_std, "OBL + MI Reward + MI Loss", color_list[2]),
    # Disabled utility comparisons (their arrays are not loaded above):
    # "OBL + MI + OBL Util", "OBL + MI + IQL Util", "OBL + MI + DIAL-IQL Util"
]

ax = plt.gca()
for curve, sem, legend_label, colour in baseline_series:
    ax.plot(episodes_indices, curve.squeeze(), label=legend_label, color=colour)
    # ax.fill_between(episodes_indices, curve - sem, curve + sem, facecolor=colour, alpha=0.3)

# Drop the top/right frame lines for a cleaner look.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)

ax.legend()
ax.set_ylabel("Steps to Phone Booth")
ax.set_xlabel("Episodes")
# Alternative title used elsewhere: "Cheap Talk Discovery"
ax.set_title("Cheap Talk Discovery - Baselines")
plt.show()

# print(np.mean(sender_ttpb_iql_mean))
# print(np.mean(sender_ttpb_obl_mean))
# print(np.mean(sender_ttpb_obl_mi_mi_loss_mean))
# print(np.mean(sender_ttpb_obl_mi_mi_loss_iql_util_mean))
# print(np.mean(sender_ttpb_obl_dial_mi_mi_loss_mean))

# sender_ttpb_ir = np.load(RESULT_PATH + "sender_time_to_pb_mi_log2_argmax.npy")
# sender_ttpb_ir_mean = np.mean(sender_ttpb_ir, axis = 0)
# sender_ttpb = np.load(RESULT_PATH + "sender_time_to_pb_argmax.npy")
# sender_ttpb_mean = np.mean(sender_ttpb, axis = 0)
#
# # Plot Mean
# plt.plot(episodes_indices, sender_ttpb_ir_mean.squeeze(), label = "OBL with MI reward")
# plt.plot(episodes_indices, sender_ttpb_mean.squeeze(), label = "OBL")
# plt.legend()
# plt.ylabel("Steps to Phone Booth")
# plt.xlabel("Episodes")
# plt.show()
# print(sender_ttpb_ir.shape)
# print(np.mean(sender_ttpb_ir))
# print(np.mean(sender_ttpb))

# Plot every run
# for i in range(NUM_RUNS):
#     if(i == 0):
#         plt.plot(episodes_indices, sender_ttpb_ir[4].squeeze(), label = "OBL with MI reward", color = "blue")
#         plt.plot(episodes_indices, sender_ttpb[4].squeeze(), label = "OBL", color = "orange")
#     else:
#         plt.plot(episodes_indices, sender_ttpb_ir[i].squeeze(), label = "_", color = "blue")
#         plt.plot(episodes_indices, sender_ttpb[i].squeeze(), label = "_", color = "orange")
#     plt.legend()
#     plt.ylabel("Steps to Phone Booth")
#     plt.xlabel("Episodes")
#     plt.show()
# fig, axes = plt.subplots(3, 2)
# print(axes.shape)
# j = 0
# for i in range(5):
#     if(i == 0):
#         axes[j, i % 2].plot(episodes_indices, sender_ttpb_ir[i].squeeze(), label = "OBL with MI reward", color = "blue")
#         axes[j, i % 2].plot(episodes_indices, sender_ttpb[i].squeeze(), label = "OBL", color = "orange")
#     else:
#         axes[j, i % 2].plot(episodes_indices, sender_ttpb_ir[i].squeeze(), label = "_", color = "blue")
#         axes[j, i % 2].plot(episodes_indices, sender_ttpb[i].squeeze(), label = "_", color = "orange")
#     if((i % 2) == 1):
#         j += 1
#     print(j)
# plt.legend()
# plt.ylabel("Steps to Phone Booth")
# plt.xlabel("Episodes")
# plt.show()
